import json
import os
from typing import Optional, Literal

from jinja2 import Environment, FileSystemLoader, select_autoescape

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import Cache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache, proxy_config
from litellm.types.utils import ModelResponse, ChatCompletionMessageToolCall
class ToolHandler(CustomLogger):
    """Proxy hook that emulates OpenAI-style tool calling for models without
    native support.

    Pre-call: renders the request's tool schemas into a Jinja2 prompt template
    and appends it to the system message, then strips the structured fields the
    upstream model would reject. Post-call: parses ```json fenced blocks out of
    the completion text and surfaces them as ``tool_calls`` on the response.
    Only models whose ``model_info.use_proxy`` flag is set in the proxy config
    are affected.
    """

    def __init__(self, message_logging=True):
        super().__init__(message_logging)
        # Jinja2 environment for rendering the tool-description prompt
        # (templates live next to this file in ./templates).
        template_path = os.path.join(os.path.dirname(__file__), 'templates')
        self.env = Environment(
            loader=FileSystemLoader(template_path),
            autoescape=select_autoescape()
        )
        # Short-lived map of tool_call id -> tool call, so a later 'tool'
        # result message can be rewritten with the function name it answers.
        self.function_cache = Cache(type="local", default_in_memory_ttl=180.0)

    def support_proxy(self, data: dict) -> bool:
        """Return the matched model's ``model_info.use_proxy`` flag.

        Falls back to False when the model is not in ``model_list`` or the
        flag is absent (missing config keys no longer raise KeyError).
        """
        for model in proxy_config.config.get('model_list', []):
            if model.get('model_name') == data.get('model'):
                model_info = model.get('model_info') or {}
                if 'use_proxy' in model_info:
                    return model_info['use_proxy']
        return False

    async def async_pre_call_hook(self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ]):
        """Rewrite an outgoing completion request for a non-tool-capable model.

        - Renders ``data['tools']`` into the system prompt and nulls the field.
        - Drops ``tool_calls`` from prior assistant messages.
        - Converts 'tool' result messages into plain user messages, recovering
          the function name from ``function_cache`` via ``tool_call_id``.
        Returns the (mutated) ``data`` dict.
        """
        if not self.support_proxy(data) or call_type != "completion":
            return data
        verbose_proxy_logger.debug(f"use function proxy to hook data: {data}")
        tools = data.get('tools')
        if isinstance(tools, list) and tools:
            # Describe the tools in prose instead of the structured field.
            template = self.env.get_template('prompt_template.j2')
            rendered_prompt = template.render(tools=tools)
            for message in data.get('messages', []):
                if message.get('role') == 'system':
                    message['content'] = f"{message['content']}\n\n{rendered_prompt}"
            data["tools"] = None
        for message in data.get('messages', []):
            # The upstream model has no tool-calling; structured tool_calls
            # on assistant turns would be rejected.
            message.pop('tool_calls', None)
            if message.get('role') != 'tool':
                continue
            tool_call_id = message.get('tool_call_id')
            if tool_call_id is None:
                continue
            tool_call = await self.function_cache.async_get_cache(cache_key=tool_call_id)
            if tool_call is None:
                continue
            # The cache holds whatever async_post_call_success_hook stored —
            # a ChatCompletionMessageToolCall, which depending on the litellm
            # version may or may not be a dict subclass. Handle both so the
            # rewrite actually fires (the old dict-only check silently skipped
            # pydantic objects).
            if isinstance(tool_call, dict):
                function_name = tool_call['function']['name']
            else:
                function_name = tool_call.function.name
            message['role'] = 'user'
            message['content'] = f"Function: {function_name}, Result: {message['content']}"
            del message['tool_call_id']
        return data

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: ModelResponse,
    ):
        """Parse ```json fenced blocks in the first choice's content into
        ``tool_calls`` and cache each by id for the pre-call rewrite.

        Each block may hold a single ``{"name": ..., "arguments": ...}``
        object or a list of them. Parse failures are logged, never raised.
        Mutates ``response`` in place; returns None.
        """
        if not self.support_proxy(data):
            return
        verbose_proxy_logger.debug(f"use function proxy to hook response: {response}")
        choice = response.choices[0]
        content = choice.message.content
        if not content:
            # Nothing to parse (e.g. empty or None content).
            return
        tool_calls = []
        try:
            for block in content.split('```json')[1:]:
                json_str = block.split('```')[0].strip()
                json_data = json.loads(json_str)
                items = json_data if isinstance(json_data, list) else [json_data]
                for item in items:
                    # Guard on dict: the old substring check ('name' in item)
                    # let plain strings through and crashed into the except.
                    if isinstance(item, dict) and 'name' in item and 'arguments' in item:
                        tool_call = ChatCompletionMessageToolCall(function=item)
                        tool_calls.append(tool_call)
                        # Cache by id so the follow-up 'tool' message can be
                        # mapped back to its function name.
                        await self.function_cache.async_add_cache(
                            tool_call, cache_key=tool_call.id
                        )
            if tool_calls:
                choice.finish_reason = 'tool_calls'
                choice.message.tool_calls = tool_calls
                response.choices = [choice]
        except Exception as e:
            verbose_proxy_logger.error(f"Error parsing JSON Markdown: {e}")


            

# Module-level singleton; presumably referenced by name from the litellm proxy
# callback settings (e.g. config.yaml) — confirm against the deployment config.
tool_call_handler_instance = ToolHandler()